{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|default_exp models.MultiInputNet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MultiInputNet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is an implementation created by Ignacio Oguiza (oguiza@timeseriesAI.co).\n",
    "\n",
    "It can be used to combine different types of deep learning models into a single model that accepts multiple inputs from a `MixedDataLoaders`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "from tsai.imports import *\n",
    "from tsai.models.layers import *\n",
    "from tsai.models.utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#|export\n",
    "class MultiInputNet(Module):\n",
    "    r\"\"\"Combine several models (one per input) into a single network with a joint head.\n",
    "\n",
    "    Each model is split into a backbone and its own head. The features created by all backbones\n",
    "    are concatenated along dim 1 and passed to a joint head.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, *models, c_out=None, reshape_fn=None, multi_output=False, custom_head=None, device=None, **kwargs):\n",
    "        r\"\"\"\n",
    "        Args:\n",
    "            models       : list of models (one model per dataloader in dls). They all must have a head and a head_nf attribute.\n",
    "            c_out        : output layer size. If None, the out_features of the last linear layer of the first model is used.\n",
    "            reshape_fn   : callable to transform a 3d input into a 2d input (Noop, Reshape(-1), GAP1d()). If None, GAP1d() is used.\n",
    "            multi_output : determines if the model creates M+1 output (one per model plus a combined one), or just a single output (combined one).\n",
    "                           The M+1 outputs are only returned in training mode.\n",
    "            custom_head  : allows you to pass a custom joint head. If None a MLP will be created (you can pass 'layers' to this default head using kwargs)\n",
    "            device       : cpu or cuda. If None, default_device() will be chosen.\n",
    "            kwargs       : head kwargs\n",
    "        \"\"\"\n",
    "\n",
    "        # Only inspect the first model when c_out is not provided, so models without linear layers\n",
    "        # can still be used with an explicit c_out.\n",
    "        if c_out is None: c_out = get_layers(models[0], cond=is_linear)[-1].out_features\n",
    "        self.M = len(models)\n",
    "        self.m = []\n",
    "        self.backbones = nn.ModuleList()\n",
    "        self.heads = nn.ModuleList()\n",
    "        head_nf = 0\n",
    "        for model in models:\n",
    "            # Resolve both parts before registering any of them, so a failure half way through\n",
    "            # cannot leave self.heads and self.backbones misaligned.\n",
    "            try: # if subscriptable\n",
    "                backbone, head = model[0], model[1]\n",
    "            except Exception:\n",
    "                # not subscriptable: detach the head and use the (now headless) model as backbone\n",
    "                head = model.head\n",
    "                model.head = Identity()\n",
    "                backbone = model\n",
    "            self.backbones.append(backbone)\n",
    "            self.heads.append(head)\n",
    "            self.m.append(Sequential(backbone, head))\n",
    "            head_nf += model.head_nf\n",
    "\n",
    "        # The joint head receives the concatenated features of all backbones\n",
    "        self.head_nf = head_nf\n",
    "        if custom_head is None: head = create_fc_head(head_nf, c_out, 1, **kwargs)\n",
    "        else: head = custom_head(self.head_nf, c_out, **kwargs)\n",
    "        self.heads.append(head)\n",
    "        self.multi_output = multi_output\n",
    "        # self.m holds the M individual models followed by the combined model (self)\n",
    "        self.m.append(self)\n",
    "        self.reshape = ifnone(reshape_fn, GAP1d())\n",
    "        self.concat = Concat(dim=1)\n",
    "        device = ifnone(device, default_device())\n",
    "        self.to(device=device)\n",
    "\n",
    "    def forward(self, xs):\n",
    "        r\"\"\"\n",
    "        Args:\n",
    "            xs : sequence with one input per model (same order as models). An input that is itself a list, tuple or L\n",
    "                 is unpacked when passed to its backbone.\n",
    "        Returns:\n",
    "            The joint head output. In training mode with multi_output=True, a list with one output per model\n",
    "            followed by the joint output.\n",
    "        \"\"\"\n",
    "        xs = tuple(*xs) if len(xs) == 1 else xs\n",
    "        return_all = self.training and self.multi_output\n",
    "        out, feats = [], []\n",
    "        for k in range(self.M):\n",
    "            x = xs[k]\n",
    "            # Create separate features\n",
    "            feat = self.backbones[k](*x) if isinstance(x, (list, tuple, L)) else self.backbones[k](x)\n",
    "\n",
    "            # Process features separately\n",
    "            if return_all: out.append(self.heads[k](feat))\n",
    "\n",
    "            # Collect 2d features (3d features are reduced with the reshape function)\n",
    "            if feat.ndim == 3: feat = self.reshape(feat)\n",
    "            feats.append(feat)\n",
    "\n",
    "        # Concat all features in a single call instead of re-concatenating on every iteration\n",
    "        concat_feats = feats[0] if len(feats) == 1 else self.concat(feats)\n",
    "\n",
    "        # Process joint features\n",
    "        out.append(self.heads[-1](concat_feats))\n",
    "        if return_all: return out\n",
    "        else: return out[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tsai.basics import *\n",
    "from tsai.data.all import *\n",
    "from tsai.models.utils import *\n",
    "from tsai.models.InceptionTimePlus import *\n",
    "from tsai.models.TabModel import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Feature Extraction: 100%|███████████████████████████████████████████| 40/40 [00:07<00:00,  5.23it/s]\n"
     ]
    }
   ],
   "source": [
    "#|extras\n",
    "# Load the dataset without splitting it (X, y and the split indices) and build a dataframe of time series features\n",
    "dsid = 'NATOPS'\n",
    "X, y, splits = get_UCR_data(dsid, split_data=False)\n",
    "ts_features_df = get_ts_features(X, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>epoch</th>\n",
       "      <th>train_loss</th>\n",
       "      <th>valid_loss</th>\n",
       "      <th>accuracy</th>\n",
       "      <th>roc_auc_score</th>\n",
       "      <th>time</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>0</td>\n",
       "      <td>1.780674</td>\n",
       "      <td>1.571718</td>\n",
       "      <td>0.477778</td>\n",
       "      <td>0.857444</td>\n",
       "      <td>00:05</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#|extras\n",
    "# raw ts: dataloaders and model that take the raw time series as input\n",
    "tfms  = [None, [TSCategorize()]]\n",
    "batch_tfms = TSStandardize()\n",
    "ts_dls = get_ts_dls(X, y, splits=splits, tfms=tfms, batch_tfms=batch_tfms)\n",
    "ts_model = build_ts_model(InceptionTimePlus, dls=ts_dls)\n",
    "\n",
    "# ts features: tabular dataloaders and model that take the extracted features as input\n",
    "cat_names = None\n",
    "cont_names = ts_features_df.columns[:-2] # all columns except the last two (presumably non-feature columns such as the target - TODO confirm)\n",
    "y_names = 'target'\n",
    "tab_dls = get_tabular_dls(ts_features_df, cat_names=cat_names, cont_names=cont_names, y_names=y_names, splits=splits)\n",
    "tab_model = build_tabular_model(TabModel, dls=tab_dls)\n",
    "\n",
    "# mixed: combine both dataloaders and both models, then train the combined model for 1 epoch\n",
    "mixed_dls = get_mixed_dls(ts_dls, tab_dls)\n",
    "MultiModalNet = MultiInputNet(ts_model, tab_model)\n",
    "learn = Learner(mixed_dls, MultiModalNet, metrics=[accuracy, RocAuc()])\n",
    "learn.fit_one_cycle(1, 1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([64, 6])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#|extras\n",
    "# Take one mixed batch, pass its inputs through the model and check the output shape\n",
    "(ts, (cat, cont)),yb = mixed_dls.one_batch()\n",
    "learn.model((ts, (cat, cont))).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(6, 6, True)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#|extras\n",
    "# Compare the `c` attribute of both dataloaders (presumably the number of classes - TODO confirm) and show ts_dls.cat\n",
    "tab_dls.c, ts_dls.c, ts_dls.cat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "FlattenedLoss of CrossEntropyLoss()"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#|extras\n",
    "# Show the loss function the learner is using\n",
    "learn.loss_func"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/javascript": "IPython.notebook.save_checkpoint();",
      "text/plain": [
       "<IPython.core.display.Javascript object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/Users/nacho/notebooks/tsai/nbs/130_models.MultiInputNet.ipynb saved at 2022-11-09 13:16:50\n",
      "Correct notebook to script conversion! 😃\n",
      "Wednesday 09/11/22 13:16:53 CET\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "                <audio  controls=\"controls\" autoplay=\"autoplay\">\n",
       "                    <source src=\"data:audio/wav;base64,UklGRvQHAABXQVZFZm10IBAAAAABAAEAECcAACBOAAACABAAZGF0YdAHAAAAAPF/iPh/gOoOon6w6ayCoR2ZeyfbjobxK+F2Hs0XjKc5i3DGvzaTlEaraE+zz5uLUl9f46fHpWJdxVSrnfmw8mYEScqUP70cb0Q8X41uysJ1si6Eh1jYzXp9IE2DzOYsftYRyoCY9dJ/8QICgIcEun8D9PmAaBPlfT7lq4MFIlh61tYPiCswIHX+yBaOqT1QbuW7qpVQSv9lu6+xnvRVSlyopAypbGBTUdSalrSTaUBFYpInwUpxOzhti5TOdndyKhCGrdwAfBUcXIJB69p+Vw1egB76+n9q/h6ADglbf4LvnIHfF/981ODThF4m8HiS0riJVjQ6c+/EOZCYQfJrGrhBmPVNMmNArLKhQlkXWYqhbaxXY8ZNHphLuBJsZUEckCTFVHMgNKGJytIDeSUmw4QN4Qx9pReTgb3vYX/TCBuApf75f+P5Y4CRDdN+B+tngk8c8nt03CKGqipgd13OhotwOC5x9MCAknFFcmlmtPmagFFFYOCo0qRzXMhVi57pryNmIEqJlRi8bm52PfuNM8k4dfQv+4cO12l6zCGdg3jl730uE/KAPvS+f0wEAoAsA89/XfXQgBESIn6S5luDtiC8eh/YmIfpLqt1OMp5jXg8/24MveqUNUnPZsqw0Z3yVDldnaUOqIZfXlKrm36zzWhjRhaT+r+ncHI5/otUzfd2uSt7hl/bqXtoHaCC6+mqfrAOeoDD+PJ/xf8RgLMHfH/b8GeBihZIfSXidoQSJWB52NM1iRkzz3MkxpKPbUCrbDu5d5fgTAxkSK3JoEhYD1p2omere2LZTuqYLbdWa49Cx5Dww7tyXDUnioXRkHhwJyKFvd/AfPoYy4Fl7j1/LQorgEr9/X89+0qAOAwAf13sJoL8Gkd8wt25hWIp3Heez/eKODfPcSPCzpFNRDVqf7UlmnNQKGHgqd+jgVvJVm2f265QZTpLS5byur1tpT6ajvrHq3Q2MXWIxtUCehoj8YMk5LB9hRQegeTypn+nBQWA0QHgf7f2q4C5EFt+5ucOg2YfHXtq2SSHpS0ydnTL4IxFO6pvNb4ulBdInWfcsfSc7VMmXpSmE6eeXmZThJxpsgRohEfOk86+AHCoOpOMFsx1dv8s6oYT2k17uR7ngpXod34IEJqAaPfnfyABCIBZBpl/NPI2gTQVjX134x2ExSPMeR7VtYjZMWJ0W8ftjkA/YW1durCWykvjZFKu4p9LVwVbZKNkqpxh6U+6mRC2mGq2Q3SRvsIgcpc2sIpD0Bp4uiiFhW3ecXxOGgaCDe0Vf4cLPoDv+/5/mfw1gN4KKX+17emBqBmYfBHfVYUZKFR44NBtiv41bHJUwx+RJkP1apu2VJlkTwli4qrwoo1ax1dToNCtemRSTBGXz7kJbdM/PY/Dxht0dTLziH7Ul3loJEiE0uJsfdsVTYGL8Yt/AgcMgHYA7X8S+IqAYA+QfjzpxIIVHnp7tdqzhmAstXaxzEqMETpScGC/dJP3Rmdo8LIZnOVSEF+Opxumsl1sVF+dVrE5Z6NIiZSkvVdv2zsqjdnK8HVDLlyHyNjuegogM4NA5z9+YRG9gA722H97AgOA/gSyf43zCIHdE899yuTIg3ciNXpm1jmImTDwdJPITI4RPhRugbvslbFKt2Vfr/6eTFb4W1WkY6m6YPdQjJr2tNZp3EQlko7BgXHRNz2LAc+gdwMq7IUf3R58ohtFgrbr6n7hDFWAlPr8f/T9I4CECU9/De+vgVQY5nxh4POEzybJeCTS5YnCNAZzhsRzkP1Bsmu4t4aYU07nYuerA6KWWcJYO6HHrKJjaE3Zl624UWz/QOOPjcWHc7QzdIk40yl5tCWjhIDhJX0xF4CBMvBsf10IF4Ac//Z/bPlsgAcO
wn6S6n6CwxzUewLcRoYaKzV38M23i9o493CNwL6S1UUuaQe0QpvbUfdfiqglpcRccFU+nkWwambASUiVfLyqbg49xY2eyWh1hy/Sh37XjHpaIYKD7OUEfrgS5IC09MV/1gMBgKMDyH/n9N6AhhINfh7mdoMoIZt6r9fAh1cvfHXNya6N4DzDbqi8K5WWSYlmbbAdnkpV6FxJpWSo1V8DUmGb3rMRaQBG2JJgwN9wCDnNi8HNI3dKK1aG0dvHe/UciIJf6rt+Og5wgDn59X9P/xWAKQhxf2XweYH+FjB9suGVhIMlOnlo02GJhTOdc7vFyo/TQGxs2Li7lz9NwmPurBihnVi7WSWiwKvGYntOpJiOt5drKUKMkFnE8HLxNPmJ9NG4eP8mAYUv4Np8hhi3gdruSX+3CSWAwP38f8f6UoCuDPF+6Os8gnAbKnxQ3d2F0imydzDPKIuiN5lxu8EKkrFE82kftW2az1DbYImpMqTUW3FWIJ83r5hl2koJlla7+m0+PmSOZcjcdMgwS4g11iZ6qCLUg5jkxn0QFA6BWvOvfzEFBIBHAtp/Qfa3gC4RSH5y5yeD2B/8evnYS4cULgR2CMsUja47cG/QvW6UeEhXZ3+xP51GVNVdP6Zpp+1eDFM5nMeySWghR4+TNL85cD46YIyCzKJ2kCzEhoTabXtGHs+CCemJfpMPjoDe9+t/qQALgM8Gj3++8UaBqRV2fQTjO4Q3JKd5r9TgiEYyMHTxxiWPpz8jbfq585YpTJpk960xoKFXsVoTo7yq6GGMTw==\" type=\"audio/wav\" />\n",
       "                    Your browser does not support the audio element.\n",
       "                </audio>\n",
       "              "
      ],
      "text/plain": [
       "<IPython.lib.display.Audio object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#|eval: false\n",
    "#|hide\n",
    "# Get the name of this notebook and run the notebook-to-script conversion for it\n",
    "from tsai.export import get_nb_name; nb_name = get_nb_name(locals())\n",
    "from tsai.imports import create_scripts; create_scripts(nb_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
